"""Token base class."""

import copy
import typing
from typing import overload

from cki_lib import logger
from cki_lib import misc

from . import secrets
from . import utils

LOGGER = logger.get_logger(__name__)
# Maps a token_type string to its Token subclass; populated by the
# register_token decorator below and consulted by get_token().
TOKEN_REGISTRY: dict[str, type['Token']] = {}


class Token:
    """Token base class.

    Concrete token types subclass this and override the low-level hooks
    (``_create_token``, ``_destroy_token`` and optionally ``_rotate_token``),
    registering themselves via the ``register_token`` decorator.  All state
    about token versions lives in the secrets file managed by the ``secrets``
    module; versions are addressed as ``<token_group>/<version>`` with their
    meta data stored under the ``#`` suffix.
    """

    # Backend identifier; assigned by the register_token decorator.
    token_type: str
    # Number of active versions clean() keeps around; 0 disables cleaning.
    clean_active_versions: int = 0

    def __init__(
        self,
        *,
        token_group: str,
        force: bool,
    ) -> None:
        """Token class.

        token_group: secrets-file key prefix shared by all versions.
        force: allow prepare/switch even when the checks report nothing to do.
        """
        self.all_token_data = secrets.read_secrets_file()
        self.token_group = token_group
        self.force = force
        # True when the subclass overrides _rotate_token, i.e. the backend
        # supports rotating in place instead of create-new + destroy-old.
        self.supports_native_rotation = type(self)._rotate_token is not Token._rotate_token

        self.active_token_data = self.get_active_token_data(token_group)

    def full_token_name(self, token_version: str) -> str:
        """Return full token name.

        An empty version refers to the bare token group itself.
        """
        return f'{self.token_group}/{token_version}' if token_version else self.token_group

    def get_active_token_data(
        self,
        token_group: str,
    ) -> dict[str, typing.Any]:
        """Return active matching tokens.

        Matches the bare token group key as well as any
        ``<token_group>/<version>`` key, keeping only entries whose
        ``meta/active`` flag is truthy.
        """
        return {
            k: v for k, v in self.all_token_data.items()
            if (k == token_group or k.startswith(f'{token_group}/'))
            and misc.get_nested_key(v, 'meta/active')
        }

    def check_invariants(self) -> None:
        """Validate basic validity of all active token versions.

        Raises ValueError unless exactly one active version is deployed,
        every version carries the required meta fields with the right types,
        and all versions agree on this token type.
        """
        # Exactly one active version must be marked as deployed.
        len_deployed_tokens = len(list(
            t for t in self.active_token_data.values() if misc.get_nested_key(t, 'meta/deployed')
        ))
        if not len_deployed_tokens:
            raise ValueError(f'No deployed token for {self.token_group}')
        if len_deployed_tokens > 1:
            raise ValueError(f'More than one deployed token for {self.token_group}')

        # Every active version needs these meta fields with the given type.
        required_fields = (
            ('active', bool),
            ('deployed', bool),
            ('created_at', str),
        )
        for field, field_type in required_fields:
            if not all(
                isinstance(misc.get_nested_key(t, f'meta/{field}'), field_type)
                for t in self.active_token_data.values()
            ):
                raise ValueError(f'Missing {field} for {self.token_group}')

        # All active versions must belong to this token type.
        if not all(
            misc.get_nested_key(t, 'meta/token_type') == self.token_type
            for t in self.active_token_data.values()
        ):
            raise ValueError(f'Inconsistent token types for {self.token_group}')

    def create(self, token_version: str) -> None:
        """Create a token, and update the secrets file."""
        token_name = self.full_token_name(token_version)
        print(f'Creating {token_name}')

        # Start from any pre-existing meta data and mark the version active
        # with a fresh creation timestamp.
        meta = secrets.secret(f'{token_name}#') | {
            'active': True,
            'created_at': misc.now_tz_utc().isoformat()
        }
        try:
            token_secret = self._create_token(token_version, meta)
        except Exception as err:
            # Record the failure in the meta data before re-raising so the
            # half-created version is not considered active.
            secrets.edit(f"{token_name}#", meta | {"active": False, "error": str(err)})
            raise

        secrets.edit(f'{token_name}#', meta)
        # Dict secrets are written with the ':' suffix, plain string secrets
        # under the bare name.
        if isinstance(token_secret, dict):
            secrets.edit(f'{token_name}:', token_secret)
        else:
            secrets.edit(f'{token_name}', token_secret)

    def _create_active(self) -> None:
        """Create another token version, and update the secrets file."""
        # Use any active version's meta data as the template for the new one.
        first_token = next(iter(self.active_token_data.values()))

        # Version names are derived from the current UNIX timestamp.
        new_version = str(int(misc.now_tz_utc().timestamp()))
        new_name = self.full_token_name(new_version)
        print(f'Creating version of token {self.token_group} by recreating {new_name}')
        secrets.edit(f'{new_name}#', first_token['meta'] | {'deployed': False})

        self.create(new_version)

    def destroy(self, token_version: str) -> None:
        """Destroy a token, and update the secrets file.

        Raises ValueError for the currently deployed version.
        """
        token_name = self.full_token_name(token_version)
        print(f'Destroying {token_name}')

        meta = secrets.secret(f'{token_name}#')
        if meta['deployed']:
            raise ValueError(f'Not destroying deployed version {token_name}')

        self._destroy_token(token_version, meta)
        # Keep the meta data around, but mark the version as inactive.
        secrets.edit(f'{token_name}#', meta | {'active': False})

    def rotate(self, token_version: str) -> None:
        """Rotate a token, and update the secrets file.

        Replaces the given (undeployed) version with a freshly created one,
        either via native rotation or via create-new + destroy-old.
        Raises ValueError for the currently deployed version.
        """
        token_name = self.full_token_name(token_version)
        print(f'Rotating {token_name}')

        meta = secrets.secret(f'{token_name}#')
        if meta['deployed']:
            raise ValueError(f'Not rotating deployed version {token_name}')

        new_version = str(int(misc.now_tz_utc().timestamp()))
        new_name = self.full_token_name(new_version)
        # The new version inherits the old meta data with a fresh timestamp.
        new_meta = copy.deepcopy(meta) | {
            'active': True,
            'created_at': misc.now_tz_utc().isoformat()
        }

        if self.supports_native_rotation:
            token_secret = self._rotate_token(token_version, new_version, meta, new_meta)
        else:
            # Fall back to creating a new token and destroying the old one.
            token_secret = self._create_token(new_version, new_meta)
            self._destroy_token(token_version, meta)
        # Deactivate the old version before recording the new one.
        secrets.edit(f'{token_name}#', meta | {'active': False})
        secrets.edit(f'{new_name}#', new_meta)
        if isinstance(token_secret, dict):
            secrets.edit(f'{new_name}:', token_secret)
        else:
            secrets.edit(f'{new_name}', token_secret)

    def _rotate_active(self) -> None:
        """Create another token version by rotation, and update the secrets file."""
        # Rotate the first active version that is not currently deployed.
        token_name = next(k for k, v in self.active_token_data.items()
                          if not misc.get_nested_key(v, 'meta/deployed'))
        _, token_version = utils.split_token_name(token_name)

        self.rotate(token_version)

    def prepare(self) -> None:
        """Prepare the token group for rotation.

        Ensures an undeployed version newer than the deployed one exists by
        creating or rotating a version.  When check_needs_prepare() reports
        nothing to do, only continues with --force and confirmation.
        """
        print(f'Preparing {self.token_group}')

        if needs_prepare := self.check_needs_prepare():
            LOGGER.debug(needs_prepare)
        else:
            # Nothing to do; only continue when forced and confirmed.
            if not self.force:
                print(f"Not preparing {self.token_group} without --force")
                return
            if not utils.confirm("Force prepare"):
                print(f"Not preparing {self.token_group}")
                return

        if len(self.active_token_data) < 2:
            self._create_active()
        else:
            self._rotate_active()

    def switch(self) -> None:
        """Switch the deployed tokens.

        Flips the ``deployed`` flag between the two active versions, after
        verifying the group is prepared and rotation is actually needed
        (or --force was given and confirmed).
        """
        print(f'Switching {self.token_group}')

        if needs_prepare := self.check_needs_prepare():
            LOGGER.debug(needs_prepare)
            print(f'Not switching {self.token_group} because it needs preparation')
            return
        if not self.check_needs_rotate():
            # Deployed token is still fine; only continue when forced.
            if not self.force:
                print(f"Not switching {self.token_group} without --force")
                return
            if not utils.confirm("Force switch"):
                print(f"Not switching {self.token_group}")
                return

        # NOTE(review): this unpacking assumes exactly two active versions;
        # subclasses with clean_active_versions > 2 can legitimately have
        # more, which would raise here — confirm intended workflow.
        token_active, token_deployed = list(self.active_token_data.items())
        if misc.get_nested_key(token_active[1], 'meta/deployed'):
            token_active, token_deployed = token_deployed, token_active

        secrets.edit(f'{token_active[0]}#deployed', True)
        secrets.edit(f'{token_deployed[0]}#deployed', False)

    def clean(self) -> None:
        """Clean the token group.

        Destroys undeployed versions so that at least clean_active_versions
        active versions remain.
        """
        print(f'Clean {self.token_group}')

        if self.check_needs_clean():
            tokens = [
                k for k, v in self.active_token_data.items()
                if not misc.get_nested_key(v, 'meta/deployed')
            ]
            # Trim the candidate list from the end so that destroying the
            # rest leaves at least clean_active_versions active versions.
            while tokens and len(tokens) + self.clean_active_versions > len(self.active_token_data):
                tokens.pop()
            for token_name in tokens:
                _, token_version = utils.split_token_name(token_name)
                self.destroy(token_version)

    def check_needs_prepare(self) -> str | None:
        """Check whether the token needs preparation before it can be rotated.

        Return None if the token group does not need rotation, or the reason rotation is needed.
        """
        if (count := len(self.active_token_data)) < 2:
            return f'Found {count} versions instead of 2'

        # NOTE(review): unpacking assumes exactly two active versions; with
        # more than two this raises instead of returning a reason — confirm.
        token_active, token_deployed = list(self.active_token_data.values())
        if misc.get_nested_key(token_active, 'meta/deployed'):
            token_active, token_deployed = token_deployed, token_active

        # The undeployed version must be newer than the deployed one so it
        # can be switched to.
        if (
            (active_created_at := utils.tz(token_active, "created_at"))
            and (deployed_created_at := utils.tz(token_deployed, "created_at"))
            and deployed_created_at > active_created_at
        ):
            return 'Deployed version is newer than undeployed version'

        LOGGER.debug(
            "No need to prepare %s: active_created_at > deployed_created_at (%s > %s)",
            self.token_group,
            active_created_at.isoformat() if active_created_at else None,
            deployed_created_at.isoformat() if deployed_created_at else None,
        )

        return None

    def check_needs_rotate(self) -> str | None:
        """Return whether the token needs rotation because of age/expiry.

        Return None if the token group does not need rotation, or the reason rotation is needed.
        """
        deployed_token = next(v for v in self.active_token_data.values()
                              if misc.get_nested_key(v, 'meta/deployed'))
        if utils.too_old(utils.tz(deployed_token, 'created_at'), utils.DEFAULT_INTERVAL):
            return 'Token too old'
        # expires_at is optional meta data; only checked when present.
        if (expires_at := utils.tz(deployed_token, "expires_at")) and utils.too_old(expires_at):
            return 'Token close to expiry'
        LOGGER.debug(
            "No need to rotate deployed token %s. (expires_at=%s)",
            self.token_group,
            expires_at.isoformat() if expires_at else None,
        )
        return None

    def check_needs_clean(self) -> str | None:
        """Check whether the token group needs cleaning from rotation.

        Return None if the token group is clean, or the reason cleaning is needed.
        """
        # clean_active_versions == 0 means cleaning is disabled for the type.
        if (self.clean_active_versions and
                (count := len(self.active_token_data)) > self.clean_active_versions):
            return f'Found {count} active versions instead of {self.clean_active_versions}'

        return None

    def update(self, token_version: str) -> None:
        """Update meta data about a token version in the secrets file."""
        token_name = self.full_token_name(token_version)
        print(f'Updating {token_name}')

        meta = secrets.secret(f'{token_name}#')

        # The hook modifies meta in place; persist whatever it changed.
        self._update_token(token_version, meta)
        secrets.edit(f'{token_name}#', meta)

    def validate(self, token_version: str) -> None:
        """Check validity of a token version.

        Runs the group-wide invariant checks plus the type-specific hook.
        """
        token_name = self.full_token_name(token_version)
        print(f'Validating {token_name}')

        meta = secrets.secret(f'{token_name}#')

        self.check_invariants()
        self._validate_token(token_version, meta)

    def purge(self, token_version: str) -> None:
        """Remove an inactive token in the secrets file.

        Raises ValueError for active or deployed versions.
        """
        token_name = self.full_token_name(token_version)
        print(f'Purging {token_name}')

        meta = secrets.secret(f'{token_name}#')
        if meta['active']:
            raise ValueError(f'Not purging active version {token_name}')
        if meta['deployed']:
            raise ValueError(f'Not purging deployed version {token_name}')

        # Remove both the secret itself and its meta data.
        secrets.edit(f'{token_name}:', None)
        secrets.edit(f'{token_name}#', None)

    # Low-level interface

    def _create_token(
        self,
        token_version: str,
        meta: dict[str, typing.Any],
    ) -> str | dict[str, str]:
        """Create a token version and return secret.

        The meta data can be modified and is updated by the caller.
        Subclasses must override this; the default raises ValueError.
        """
        raise ValueError(f'No support to create {self.token_type} tokens')

    def _destroy_token(self, token_version: str, meta: dict[str, typing.Any]) -> None:
        """Destroy a token version.

        The meta data can be modified and is updated by the caller.
        Subclasses must override this; the default raises ValueError.
        """
        raise ValueError(f'No support to destroy {self.token_type} tokens')

    def _rotate_token(
        self,
        old_version: str,
        new_version: str,
        old_meta: dict[str, typing.Any],
        new_meta: dict[str, typing.Any],
    ) -> str | dict[str, str]:
        """Rotate a token version and return new secret.

        The original and new meta data can be modified and are updated by the caller.
        Overriding this hook enables native rotation (see
        supports_native_rotation in __init__).
        """
        raise ValueError(f'No support to rotate {self.token_type} tokens')

    # Misc helpers

    def _update_token(self, token_version: str, meta: dict[str, typing.Any]) -> None:
        """Update a token version.

        The meta data can be modified and is updated by the caller.
        """
        # by default, do nothing

    def _validate_token(self, token_version: str, meta: dict[str, typing.Any]) -> None:
        """Validate a token version."""
        # by default, do nothing


class register_token:
    # pylint: disable=too-few-public-methods, invalid-name
    """Decorator that registers a token class under a token type."""

    def __init__(self, token_type: str):
        """Remember the token type the decorated class will handle."""
        self.token_type = token_type

    def __call__(self, cls: type[Token]) -> type[Token]:
        """Tag the class with its token type, record it, and return it unchanged."""
        cls.token_type = self.token_type
        TOKEN_REGISTRY[self.token_type] = cls
        return cls


@overload
def get_token(
    token_group_or_name: str,
    force: bool,
    raise_if_missing: typing.Literal[False],
) -> Token | None: ...


@overload
def get_token(
    token_group_or_name: str,
    force: bool,
    raise_if_missing: typing.Literal[True] = True,
) -> Token: ...


def get_token(
    token_group_or_name: str,
    force: bool,
    raise_if_missing: bool = True,
) -> Token | None:
    """Return a token object."""
    # Look up the registered class for the stored token type.
    token_type = secrets.secret(f'{token_group_or_name}[]#token_type')[0]
    cls = TOKEN_REGISTRY.get(token_type)
    if cls is None:
        if raise_if_missing:
            raise ValueError(f'Unknown token type for {token_group_or_name}')
        return None
    token_group, _ = utils.split_token_name(token_group_or_name)
    return cls(token_group=token_group, force=force)
